{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91dcbbd9-49ed-40fd-a9f8-4b3964634136",
   "metadata": {},
   "outputs": [],
   "source": [
    "#0) Install + Login + Load LLaMA-3.1-8B-Instruct\n",
    "\n",
    "# --- Install dependencies ---\n",
    "!pip -q install -U transformers accelerate bitsandbytes sentencepiece huggingface_hub openpyxl tqdm\n",
    "\n",
    "# --- Hugging Face login (required for Meta LLaMA repositories in many cases) ---\n",
    "from huggingface_hub import login\n",
    "login()  # will prompt you to paste your HF token\n",
    "\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig\n",
    "\n",
    "MODEL_ID = \"meta-llama/Llama-3.1-8B-Instruct\"  # adjust if your repo name differs\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_ID,\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n",
    "    load_in_4bit=True,  # memory-friendly on Colab GPUs\n",
    ")\n",
    "\n",
    "# Conservative generation settings for classification stability\n",
    "GEN_CFG = GenerationConfig(\n",
    "    max_new_tokens=220,\n",
    "    temperature=0.2,\n",
    "    top_p=0.9,\n",
    "    do_sample=True,\n",
    ")\n",
    "\n",
    "#1) Prompt Template + JSON Extraction\n",
    "\n",
    "import json\n",
    "import re\n",
    "\n",
    "PROMPT_TEMPLATE = \"\"\"\n",
    "You are coding a news article for a political communication study.\n",
    "\n",
    "Task A: Threat Object\n",
    "Which external actor is primarily framed as a threat in this article?\n",
    "1 = North Korea\n",
    "2 = China\n",
    "3 = Both North Korea and China\n",
    "4 = Neither / not framed as a threat\n",
    "\n",
    "Task B: Affective Orientation\n",
    "Which evaluative orientation best characterizes the threat narrative?\n",
    "1 = Crisis-oriented (episodic danger, immediacy, emergency response logic)\n",
    "2 = Grievance-oriented (unfairness, intrusion, long-term harm, responsibility attribution)\n",
    "3 = Mixed\n",
    "4 = Low affect / descriptive\n",
    "\n",
    "Task C: Evidence\n",
    "Provide up to two short excerpts (maximum 25 words each) that justify the classification in Task B.\n",
    "\n",
    "Return output in the following JSON format only:\n",
    "{{\n",
    "  \"threat_object\": <1-4>,\n",
    "  \"affective_orientation\": <1-4>,\n",
    "  \"evidence\": [\"...\", \"...\"],\n",
    "  \"confidence\": <0-100>\n",
    "}}\n",
    "\n",
    "Article text:\n",
    "<<<\n",
    "{ARTICLE_TEXT}\n",
    ">>>\n",
    "\"\"\".strip()\n",
    "\n",
    "def extract_json_object(text: str):\n",
    "    \"\"\"\n",
    "    Robustly extract the first JSON object found in a text block.\n",
    "    Returns dict or None.\n",
    "    \"\"\"\n",
    "    m = re.search(r\"\\{.*\\}\", text, flags=re.S)\n",
    "    if not m:\n",
    "        return None\n",
    "    try:\n",
    "        return json.loads(m.group(0))\n",
    "    except json.JSONDecodeError:\n",
    "        return None\n",
    "\n",
    "#2) Token-Based Segmentation\n",
    "\n",
    "def chunk_text_by_tokens(text: str, max_input_tokens: int = 900, overlap_tokens: int = 80):\n",
    "    \"\"\"\n",
    "    Token-based chunking for long news articles.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    text : str\n",
    "        Full article text.\n",
    "    max_input_tokens : int\n",
    "        Max tokens per chunk for the article text (excluding prompt tokens).\n",
    "    overlap_tokens : int\n",
    "        Overlap between consecutive chunks to reduce boundary effects.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list[str]\n",
    "        List of chunk texts.\n",
    "    \"\"\"\n",
    "    if not isinstance(text, str) or not text.strip():\n",
    "        return []\n",
    "\n",
    "    token_ids = tokenizer.encode(text, add_special_tokens=False)\n",
    "    chunks = []\n",
    "\n",
    "    start = 0\n",
    "    n = len(token_ids)\n",
    "    while start < n:\n",
    "        end = min(start + max_input_tokens, n)\n",
    "        chunk_ids = token_ids[start:end]\n",
    "        chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True)\n",
    "        chunks.append(chunk_text)\n",
    "\n",
    "        if end == n:\n",
    "            break\n",
    "        start = max(0, end - overlap_tokens)\n",
    "\n",
    "    return chunks\n",
    "\n",
    "#3) Segment-Level Classification Function\n",
    "\n",
    "def classify_segment(seg_text: str):\n",
    "    \"\"\"\n",
    "    Run one segment through the model and return:\n",
    "    - raw completion text\n",
    "    - parsed JSON dict (or None)\n",
    "    \"\"\"\n",
    "    prompt = PROMPT_TEMPLATE.format(ARTICLE_TEXT=seg_text)\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        out_ids = model.generate(**inputs, generation_config=GEN_CFG)\n",
    "\n",
    "    decoded = tokenizer.decode(out_ids[0], skip_special_tokens=True)\n",
    "\n",
    "    # Remove prompt prefix if the model echoed it\n",
    "    completion = decoded[len(prompt):].strip() if decoded.startswith(prompt) else decoded.strip()\n",
    "\n",
    "    parsed = extract_json_object(completion)\n",
    "    return completion, parsed\n",
    "\n",
    "#4) Modal Aggregation + Confidence Tie-Break (Article-Level)\n",
    "\n",
    "import numpy as np\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "def _safe_conf(x):\n",
    "    try:\n",
    "        return float(x)\n",
    "    except Exception:\n",
    "        return 0.0\n",
    "\n",
    "def modal_with_confidence_tiebreak(values, confidences):\n",
    "    \"\"\"\n",
    "    Pick modal value; if tie, pick the category with the larger summed confidence.\n",
    "    \"\"\"\n",
    "    counts = Counter(values)\n",
    "    top_count = max(counts.values())\n",
    "    candidates = [k for k, v in counts.items() if v == top_count]\n",
    "    if len(candidates) == 1:\n",
    "        return candidates[0]\n",
    "\n",
    "    conf_sum = defaultdict(float)\n",
    "    for v, c in zip(values, confidences):\n",
    "        conf_sum[v] += c\n",
    "    candidates = sorted(candidates, key=lambda x: conf_sum[x], reverse=True)\n",
    "    return candidates[0]\n",
    "\n",
    "def aggregate_article(segment_outputs):\n",
    "    \"\"\"\n",
    "    Aggregate segment-level outputs to an article-level label.\n",
    "\n",
    "    segment_outputs: list[dict] where each dict contains:\n",
    "      - \"parsed\": dict or None\n",
    "      - \"raw\": raw model output string\n",
    "    \"\"\"\n",
    "    valid = [s for s in segment_outputs if s.get(\"parsed\") is not None]\n",
    "    if not valid:\n",
    "        return {\n",
    "            \"threat_object\": None,\n",
    "            \"affective_orientation\": None,\n",
    "            \"confidence\": None,\n",
    "            \"evidence\": None,\n",
    "            \"n_segments\": len(segment_outputs),\n",
    "            \"n_valid_segments\": 0\n",
    "        }\n",
    "\n",
    "    threat_objs = [v[\"parsed\"].get(\"threat_object\") for v in valid]\n",
    "    orientations = [v[\"parsed\"].get(\"affective_orientation\") for v in valid]\n",
    "    confs = [_safe_conf(v[\"parsed\"].get(\"confidence\")) for v in valid]\n",
    "\n",
    "    # Modal + confidence tie-break\n",
    "    threat_object = modal_with_confidence_tiebreak(threat_objs, confs)\n",
    "    affective_orientation = modal_with_confidence_tiebreak(orientations, confs)\n",
    "\n",
    "    # Article-level confidence: mean confidence among segments matching chosen orientation\n",
    "    matched_confs = [c for v, c in zip(valid, confs)\n",
    "                     if v[\"parsed\"].get(\"affective_orientation\") == affective_orientation]\n",
    "    article_conf = float(np.mean(matched_confs)) if matched_confs else float(np.mean(confs))\n",
    "\n",
    "    # Evidence: take from highest-confidence segments consistent with chosen orientation\n",
    "    evidence_pool = []\n",
    "    for v, c in zip(valid, confs):\n",
    "        if v[\"parsed\"].get(\"affective_orientation\") == affective_orientation:\n",
    "            ev = v[\"parsed\"].get(\"evidence\")\n",
    "            if isinstance(ev, list) and ev:\n",
    "                evidence_pool.append((c, ev))\n",
    "\n",
    "    evidence_pool.sort(key=lambda x: x[0], reverse=True)\n",
    "\n",
    "    chosen = []\n",
    "    for _, ev_list in evidence_pool[:2]:\n",
    "        for e in ev_list:\n",
    "            if isinstance(e, str) and e.strip():\n",
    "                chosen.append(e.strip())\n",
    "        if len(chosen) >= 2:\n",
    "            break\n",
    "    chosen = chosen[:2] if chosen else None\n",
    "\n",
    "    return {\n",
    "        \"threat_object\": threat_object,\n",
    "        \"affective_orientation\": affective_orientation,\n",
    "        \"confidence\": round(article_conf, 2),\n",
    "        \"evidence\": chosen,\n",
    "        \"n_segments\": len(segment_outputs),\n",
    "        \"n_valid_segments\": len(valid)\n",
    "    }\n",
    "\n",
    "#5) Excel Upload → Batch Inference → Excel Save\n",
    "#This block:\n",
    "# - uploads an Excel file\n",
    "# - expects an article text column (default: \"contents\")\n",
    "# - classifies each row\n",
    "# - writes results to a new Excel and downloads it\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from google.colab import files\n",
    "\n",
    "# --- Upload Excel ---\n",
    "uploaded = files.upload()\n",
    "in_file = list(uploaded.keys())[0]\n",
    "df = pd.read_excel(in_file)\n",
    "\n",
    "# --- Set your text column name here ---\n",
    "TEXT_COL = \"contents\"  # change if needed\n",
    "df[TEXT_COL] = df[TEXT_COL].astype(str)\n",
    "\n",
    "# --- Chunking parameters (safe defaults) ---\n",
    "MAX_INPUT_TOKENS = 900\n",
    "OVERLAP_TOKENS = 80\n",
    "\n",
    "results = []\n",
    "\n",
    "for i in tqdm(range(len(df))):\n",
    "    text = df.loc[i, TEXT_COL]\n",
    "\n",
    "    chunks = chunk_text_by_tokens(\n",
    "        text,\n",
    "        max_input_tokens=MAX_INPUT_TOKENS,\n",
    "        overlap_tokens=OVERLAP_TOKENS\n",
    "    )\n",
    "\n",
    "    seg_outputs = []\n",
    "    for seg in chunks:\n",
    "        raw, parsed = classify_segment(seg)\n",
    "        seg_outputs.append({\"raw\": raw, \"parsed\": parsed})\n",
    "\n",
    "    agg = aggregate_article(seg_outputs)\n",
    "\n",
    "    results.append({\n",
    "        \"threat_object\": agg[\"threat_object\"],\n",
    "        \"affective_orientation\": agg[\"affective_orientation\"],\n",
    "        \"confidence\": agg[\"confidence\"],\n",
    "        \"evidence_1\": (agg[\"evidence\"][0] if agg[\"evidence\"] else None),\n",
    "        \"evidence_2\": (agg[\"evidence\"][1] if agg[\"evidence\"] else None),\n",
    "        \"n_segments\": agg[\"n_segments\"],\n",
    "        \"n_valid_segments\": agg[\"n_valid_segments\"]\n",
    "    })\n",
    "\n",
    "df_out = pd.concat([df.reset_index(drop=True), pd.DataFrame(results)], axis=1)\n",
    "\n",
    "# --- Save to Excel ---\n",
    "out_file = \"classified_output.xlsx\"\n",
    "df_out.to_excel(out_file, index=False)\n",
    "\n",
    "files.download(out_file)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
